# Libraries:
import torch
import pandas as pd
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
from pathlib import Path
from tqdm import tqdm

# Folder that holds this script; every data path below is built from it.
script_dir = Path(__file__).resolve().parent

# Enable the tqdm-backed `progress_apply` methods on pandas objects.
tqdm.pandas()

# Corpus whose `text_en` column is scored further down in this script.
corpus_path = script_dir / "../data/source/manifesto_corpus.feather"
df = pd.read_feather(corpus_path)

# One tokenizer instance is shared by every model applied in this script.
tokenizer_checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)


class MyModel(
    nn.Module,
    PyTorchModelHubMixin,
    # optionally, you can add metadata which gets pushed to the model card
    # repo_url="your-repo-url",
    pipeline_tag="text-classification",
    license="mit",
):
    """Encoder followed by a linear transform and a two-layer classification head.

    The sub-module names (``bert``, ``invariant_trans``,
    ``moral_classification``) are part of the state-dict layout, so they must
    stay as they are for published checkpoints to load.

    Args:
        bert_model: Encoder module. It is called with ``input_ids``,
            ``token_type_ids`` and ``attention_mask`` and must return an
            object exposing ``last_hidden_state``.
        moral_label: Number of output classes of the classification head.
    """

    def __init__(self, bert_model, moral_label=2):
        super().__init__()
        self.bert = bert_model
        # Size the head from the encoder's own configuration when it exposes
        # one; 768 (the width used so far) remains the fallback.
        bert_dim = getattr(getattr(bert_model, "config", None), "hidden_size", 768)
        self.invariant_trans = nn.Linear(bert_dim, bert_dim)
        self.moral_classification = nn.Sequential(
            nn.Linear(bert_dim, bert_dim),
            nn.ReLU(),
            nn.Linear(bert_dim, moral_label),
        )

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Return the unnormalised class logits for a batch of encoded texts.

        Args:
            input_ids: Token ids, forwarded to the encoder.
            token_type_ids: Segment ids, forwarded to the encoder.
            attention_mask: Padding mask, forwarded to the encoder.

        Returns:
            The output of the classification head (logits, no softmax applied).
        """
        # Use the hidden state at sequence position 0 as the text
        # representation (the [CLS] token for BERT-style inputs). This is not
        # the encoder's separate `pooler_output`.
        first_token_state = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
        ).last_hidden_state[:, 0, :]

        transformed = self.invariant_trans(first_token_state)

        return self.moral_classification(transformed)


def preprocessing(input_text, tokenizer, device, max_length=150):
    """Tokenize a text and return its model inputs as tensors on ``device``.

    The tokenizer is invoked through its ``__call__`` interface, which is the
    supported replacement for the deprecated ``encode_plus`` method. Note that
    ``__call__`` treats a list of strings as a batch of texts.

    Args:
        input_text: Text to encode.
        tokenizer: Callable Hugging Face-style tokenizer.
        device: Target handed to ``Tensor.to`` for every returned tensor.
        max_length: Length each encoding is padded or truncated to.

    Returns:
        A plain ``dict`` (not a ``BatchEncoding``) with one tensor per field
        the tokenizer returns. The fields requested here are:
        - input_ids: token ids
        - token_type_ids: token type ids
        - attention_mask: 0/1 flags marking which tokens the model should
          attend to
    """
    encoded = tokenizer(
        input_text,
        add_special_tokens=True,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_token_type_ids=True,
        return_tensors="pt",
    )
    return {name: tensor.to(device) for name, tensor in encoded.items()}


# Moral Foundations Theory (MFT) labels to score, written as consecutive
# pairs and flattened in order into a single list.
mft_values = [
    label
    for label_pair in (
        ("care", "harm"),
        ("fairness", "cheating"),
        ("loyalty", "betrayal"),
        ("authority", "subversion"),
        ("purity", "degradation"),
    )
    for label in label_pair
]

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def apply_a_moralbert_model(sentence: str, model):
    """Score one sentence with one model and return the class-1 probability.

    The sentence is encoded with the module-level ``tokenizer`` on the
    module-level ``device``, passed through ``model`` without gradient
    tracking, and the logits are softmax-normalised over dimension 1.

    Args:
        sentence: Text to score.
        model: Callable accepting the encoded inputs as keyword arguments and
            returning logits.

    Returns:
        The probability of the second class (index 1) for the first and only
        item in the batch, as a Python float.
    """
    encoded_inputs = preprocessing(sentence, tokenizer, device)
    with torch.no_grad():
        logits = model(**encoded_inputs)
        probabilities = F.softmax(logits, dim=1)
        # Row 0 is the single encoded sentence; column 1 is the second class.
        second_class_probability = probabilities[0, 1].item()
    return second_class_probability


# Score the whole corpus with one MFT model at a time, so that only the model
# currently in use has to be loaded.
for mft in mft_values:
    print(f"Processing MFT: {mft}")

    # Load the classifier for this MFT value and move it to the device.
    model = MyModel.from_pretrained(
        f"vjosap/moralBERT-predict-{mft}-in-text",
        bert_model=AutoModel.from_pretrained("bert-base-uncased"),
    ).to(device)
    # Select evaluation mode explicitly so inference does not depend on the
    # mode the loaders happen to leave the modules in.
    model.eval()

    # Add one score column per MFT value. Only `text_en` is needed, so apply
    # over that column instead of materialising every full row. `model` is
    # bound as a default argument to pin it to this iteration's instance.
    df[mft] = df["text_en"].progress_apply(
        lambda sentence, model=model: apply_a_moralbert_model(
            sentence=sentence, model=model
        )
    )


# Make sure the output folder exists before writing into it.
scores_dir = script_dir / "../data/moralbert/"
scores_dir.mkdir(parents=True, exist_ok=True)

# Keep the identifier, the scored text and one score column per MFT value.
output_columns = ["id_for_project", "text_en"] + mft_values
scores_path = script_dir / "../data/moralbert/moralbert_scores.feather"
df[output_columns].to_feather(scores_path)
